{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 微调和部署对话模型ChatGLM2-6B\n",
    "\n",
    "[ChatGLM2-6B](https://www.modelscope.cn/models/ZhipuAI/chatglm2-6b/summary)是中英文对话模型[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 的第二代版本，在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上，ChatGLM2-6B 引入了多项升级，包括更强大的性能、更长的上下文、更高效的推理。\n",
    "\n",
    "在本示例中，我们将展示：\n",
    "\n",
    "- 将ChatGLM2-6B部署到PAI创建推理服务，基于推理服务API和Gradio实现一个简易对话机器人。\n",
    "\n",
    "- 在PAI对ChatGLM2-6B进行微调训练，并将微调的模型部署创建推理服务。\n",
    "\n",
    "\n",
    "## 准备工作\n",
    "\n",
    "### 前提条件\n",
    "\n",
    "- 已获取阿里云账号的鉴权AccessKey ID和AccessKey Secret，详情请参见：[获取AccessKey](https://help.aliyun.com/document_detail/116401.html)。\n",
    "- 已创建或是加入一个PAI AI工作空间，详情请参见：[创建工作空间](https://help.aliyun.com/document_detail/326193.html)。\n",
    "- 已创建OSS Bucket，详情请参见：[控制台创建存储空间](https://help.aliyun.com/document_detail/31885.html)。\n",
    "\n",
    "\n",
    "### 安装和配置PAI Python SDK\n",
    "\n",
    "我们将使用PAI提供的Python SDK，提交训练作业，部署模型。可以通过以下命令安装PAI Python SDK。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade alipai\n",
    "!python -m pip install gradio"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "SDK需要配置访问阿里云服务需要的 AccessKey，以及当前使用的工作空间和OSS Bucket。在PAI Python SDK安装之后，通过在 **命令行终端** 中执行以下命令，按照引导配置密钥，工作空间等信息。\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# 以下命令，请在 命令行终端 中执行.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "我们可以通过以下代码验证当前的配置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "sess = get_default_session()\n",
    "\n",
    "assert sess.workspace_name is not None\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 直接部署ChatGLM2\n",
    "\n",
    "`ChatGLM2-6B`是一个对话语言模型，能够基于历史对话信息，和用户的Prompt输入，进行反馈。通过HuggingFace的transformers库用户可以直接使用`ChatGLM2-6B`提供的对话能力，示例如下:\n",
    "\n",
    "```python\n",
    "\n",
    ">>> from transformers import AutoTokenizer, AutoModel\n",
    ">>> tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm2-6b\", trust_remote_code=True)\n",
    ">>> model = AutoModel.from_pretrained(\"THUDM/chatglm2-6b\", trust_remote_code=True).half().cuda()\n",
    ">>> model = model.eval()\n",
    ">>> response, history = model.chat(tokenizer, \"你好\", history=[])\n",
    ">>> print(response)\n",
    "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。\n",
    ">>> response, history = model.chat(tokenizer, \"晚上睡不着应该怎么办\", history=history)\n",
    ">>> print(response)\n",
    "晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:\n",
    "\n",
    "1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。\n",
    "2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。\n",
    "3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。\n",
    "4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。\n",
    "5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。\n",
    "6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。\n",
    "\n",
    "如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。\n",
    "\n",
    "\n",
    "\n",
    "```\n",
    "\n",
    "以下的流程中，我们将`ChatGLM2-6B`部署到PAI创建一个推理服务，然后基于推理服务的API，使用Gradio创建一个对话机器人。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### 获取ChatGLM2模型\n",
    "\n",
    "推理服务和训练作业中都需要加载使用模型，PAI在部分region上提供模型缓存，支持用户能够更快地获取到相应的模型。用户可以通过以下代码获取相应的模型，然后在训练作业和推理服务中加载使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import RegisteredModel\n",
    "\n",
    "m = RegisteredModel(\n",
    "    \"THUDM/chatglm2-6b\",\n",
    "    model_provider=\"huggingface\",\n",
    ")\n",
    "\n",
    "model_uri = m.model_data\n",
    "print(model_uri)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建推理服务\n",
    "\n",
    "PAI-EAS是阿里云PAI提供模型在线服务平台，支持用户一键部署推理服务或是AIWeb应用，支持异构资源，弹性扩缩容。PAI-EAS支持使用镜像的方式部署模型，以下的流程，我们将使用PAI提供的PyTorch推理镜像，将以上的模型部署为推理服务。\n",
    "\n",
    "\n",
    "在部署推理服务之前，我们需要准备相应的推理服务程序，他负责加载模型，提供对应的HTTP API服务。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p server_src"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完整的推理服务代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile server_src/run.py\n",
    "# source: https://github.com/THUDM/ChatGLM-6B/blob/main/api.py\n",
    "\n",
    "import os\n",
    "\n",
    "from fastapi import FastAPI, Request\n",
    "from transformers import AutoTokenizer, AutoModel, AutoConfig\n",
    "import uvicorn, json, datetime\n",
    "import torch\n",
    "\n",
    "\n",
    "model = None\n",
    "tokenizer = None\n",
    "\n",
    "# 默认的模型保存路径\n",
    "chatglm_model_path = \"/eas/workspace/model/\"\n",
    "# ptuning checkpoints保存路径\n",
    "ptuning_checkpoint = \"/ml/ptuning_checkpoints/\"\n",
    "pre_seq_len = 128\n",
    "app = FastAPI()\n",
    "\n",
    "\n",
    "def load_model():\n",
    "    global model, tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(chatglm_model_path, trust_remote_code=True)\n",
    "\n",
    "    if os.path.exists(ptuning_checkpoint):\n",
    "        # P-tuning v2\n",
    "        print(f\"Loading model/ptuning_checkpoint weight...\")\n",
    "        config = AutoConfig.from_pretrained(chatglm_model_path, trust_remote_code=True)\n",
    "        config.pre_seq_len = pre_seq_len\n",
    "        config.prefix_projection = False\n",
    "\n",
    "        model = AutoModel.from_pretrained(chatglm_model_path, config=config, trust_remote_code=True)\n",
    "        tokenizer = AutoTokenizer.from_pretrained(chatglm_model_path, trust_remote_code=True)\n",
    "        prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, \"pytorch_model.bin\"))\n",
    "        new_prefix_state_dict = {}\n",
    "        for k, v in prefix_state_dict.items():\n",
    "            if k.startswith(\"transformer.prefix_encoder.\"):\n",
    "                new_prefix_state_dict[k[len(\"transformer.prefix_encoder.\"):]] = v\n",
    "        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)\n",
    "\n",
    "        model = model.half().cuda()\n",
    "        model.transformer.prefix_encoder.float().cuda()\n",
    "        model.eval()\n",
    "    else:\n",
    "        print(f\"Loading model weight...\")\n",
    "        model = AutoModel.from_pretrained(chatglm_model_path, trust_remote_code=True)\n",
    "        model.half().cuda()\n",
    "        model.eval()\n",
    "\n",
    "\n",
    "\n",
    "@app.post(\"/\")\n",
    "async def create_item(request: Request):\n",
    "    global model, tokenizer\n",
    "    json_post_raw = await request.json()\n",
    "    json_post = json.dumps(json_post_raw)\n",
    "    json_post_list = json.loads(json_post)\n",
    "    prompt = json_post_list.get('prompt')\n",
    "    history = json_post_list.get('history')\n",
    "    max_length = json_post_list.get('max_length')\n",
    "    top_p = json_post_list.get('top_p')\n",
    "    temperature = json_post_list.get('temperature')\n",
    "    response, history = model.chat(tokenizer,\n",
    "                                   prompt,\n",
    "                                   history=history,\n",
    "                                   max_length=max_length if max_length else 2048,\n",
    "                                   top_p=top_p if top_p else 0.7,\n",
    "                                   temperature=temperature if temperature else 0.95)\n",
    "    now = datetime.datetime.now()\n",
    "    time = now.strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    answer = {\n",
    "        \"response\": response,\n",
    "        \"history\": history,\n",
    "        \"status\": 200,\n",
    "        \"time\": time\n",
    "    }\n",
    "    log = \"[\" + time + \"] \" + '\", prompt:\"' + prompt + '\", response:\"' + repr(response) + '\"'\n",
    "    print(log)\n",
    "    return answer\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    load_model()\n",
    "    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将使用PyTorch镜像运行相应的推理服务，在启动服务之前需要安装模型依赖的相关的依赖。我们可以在`server_src`下准备依赖的`requirements.txt`，对应的`requirements.txt`会在推理服务启动之前被安装到环境中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile server_src/requirements.txt\n",
    "\n",
    "# 模型需要的依赖\n",
    "transformers==4.30.2\n",
    "accelerate\n",
    "icetk\n",
    "cpm_kernels\n",
    "\n",
    "torch>=2.0,<2.1\n",
    "gradio\n",
    "mdtex2html\n",
    "sentencepiece\n",
    "accelerate\n",
    "\n",
    "# 推理服务Server的依赖\n",
    "fastapi\n",
    "uvicorn"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "基于以上的推理服务程序，我们将使用PyTorch镜像和OSS上的模型在PAI创建一个推理服务，代码如下。\n",
    "\n",
    "> 对于如何使用SDK创建推理服务的详细介绍，请见文档：[创建推理服务](https://help.aliyun.com/document_detail/2261532.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import container_serving_spec, Model\n",
    "from pai.image import retrieve, ImageScope\n",
    "from pai.common.utils import random_str\n",
    "\n",
    "\n",
    "# InferenceSpec用于描述如何创建推理服务\n",
    "infer_spec = container_serving_spec(\n",
    "    # 使用PAI提供的最新PyTorch的推理镜像\n",
    "    image_uri=retrieve(\n",
    "        \"PyTorch\",\n",
    "        \"latest\",\n",
    "        accelerator_type=\"GPU\",\n",
    "        image_scope=ImageScope.INFERENCE,\n",
    "    ),\n",
    "    source_dir=\"./server_src\",\n",
    "    command=\"python run.py\",\n",
    ")\n",
    "\n",
    "m = Model(\n",
    "    # 模型的OSS路径，默认模型会通过挂载的方式挂载到`/eas/workspace/model/`路径下。\n",
    "    model_data=model_uri,\n",
    "    inference_spec=infer_spec,\n",
    ")\n",
    "\n",
    "\n",
    "# 部署模型，创建推理服务.\n",
    "p = m.deploy(\n",
    "    service_name=\"chatglm_demo_{}\".format(random_str(6)),\n",
    "    instance_type=\"ecs.gn6i-c8g1.2xlarge\",  # 8vCPU 31GB NVIDIA T4×1(GPU Mem 16GB)\n",
    "    options={\n",
    "        # 配置EAS RPC框架的超时时间, 单位为毫秒\n",
    "        \"metadata.rpc.keepalive\": 20000,\n",
    "    },\n",
    ")\n",
    "\n",
    "print(p.service_name)\n",
    "print(p.service_status)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`m.deploy`返回一个Predictor对象，可以用于向创建的推理服务程序发送预测请求。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.predictor import RawResponse\n",
    "\n",
    "resp: RawResponse = p.raw_predict(\n",
    "    {\n",
    "        \"prompt\": \"你好\",\n",
    "    }\n",
    ")\n",
    "print(resp.json()[\"response\"])\n",
    "\n",
    "\n",
    "resp = p.raw_predict(\n",
    "    {\n",
    "        \"prompt\": \"晚上睡不着应该怎么办\",\n",
    "        \"history\": resp.json()[\"history\"],\n",
    "    },\n",
    "    timeout=20,\n",
    ")\n",
    "print(resp.json())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基于以上的推理服务，我们可以使用Gradio创建一个简单的对话机器人demo。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import random\n",
    "import time\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot()\n",
    "    msg = gr.Textbox()\n",
    "    clear = gr.Button(\"Clear\")\n",
    "    submit = gr.Button(\"Submit\")\n",
    "\n",
    "    def respond(message, chat_history):\n",
    "\n",
    "        print(f\"Message: {message}\")\n",
    "        print(f\"ChatHistory: {chat_history}\")\n",
    "        resp = p.raw_predict(\n",
    "            {\n",
    "                \"prompt\": message,\n",
    "                \"history\": chat_history,\n",
    "            }\n",
    "        ).json()\n",
    "        print(f\"Response: {resp['response']}\")\n",
    "\n",
    "        chat_history.append((message, resp[\"response\"]))\n",
    "        return \"\", chat_history\n",
    "\n",
    "    submit.click(respond, [msg, chatbot], [msg, chatbot])\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过以上创建的Gradio应用，我们可以在页面上与部署的ChatGLM模型进行对话。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "在测试完成之后，我们可以通过以下的代码删除推理服务，释放资源。\n",
    "\n",
    "> 请注意，删除在线推理服务之后，对应的Gradio的应用将无法使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p.delete_service()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 微调ChatGLM2-6B\n",
    "\n",
    "我们可以使用领域数据对ChatGLM进行微调，从而使得模型在特定领域和任务下有更好的表现。ChatGLM团队提供了使用[P-Tuning v2](https://github.com/THUDM/P-tuning-v2)方式对模型进行[微调的方案](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning)，我们将基于此方案展示如何将微调训练作业提交到PAI的训练服务执行。\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备训练数据集\n",
    "\n",
    "我们将使用了[广告生成数据集](https://aclanthology.org/D19-1321.pdf)，对ChatGLM进行微调。我们首先需要准备数据到OSS，供后续微调训练作业使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.oss_utils import download, OssUriObj, upload\n",
    "import zipfile\n",
    "\n",
    "# 下载数据\n",
    "data = download(\n",
    "    # 当前的数据集在上海region，跨region下载，我们需要传递对应OSS Bucket所在Endpoint.\n",
    "    OssUriObj(\n",
    "        \"oss://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/chatGLM/AdvertiseGen_Simple.zip\"\n",
    "    ),\n",
    "    local_path=\"./\",\n",
    ")\n",
    "\n",
    "# 解压缩数据\n",
    "with zipfile.ZipFile(data, \"r\") as zip_ref:\n",
    "    zip_ref.extractall(\"./train_data/\")\n",
    "\n",
    "# 上传数据到OSS\n",
    "train_data = \"./train_data/AdvertiseGen_Simple/\"\n",
    "train_data_uri = upload(\n",
    "    \"./train_data/AdvertiseGen_Simple/\", oss_path=\"chatglm_demo/data/advertisegen/\"\n",
    ")\n",
    "print(train_data_uri)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相应的数据集数据格式如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!head -n 5 ./train_data/AdvertiseGen_Simple/train.json"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备微调训练作业脚本\n",
    "\n",
    "ChatGLM的官方提供[微调训练脚本](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning)，支持使用P-Tuning v2的方式对ChatGLM模型进行微调。我们将基于相应的微调训练脚本，修改训练作业的拉起Shell脚本(`train.sh`)，然后使用PAI Python SDK将微调训练作业提交到PAI执行。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 下载ChatGLM代码\n",
    "!git clone https://github.com/THUDM/ChatGLM2-6B.git"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "\n",
    "当训练作业提交到PAI执行时，需要按一定规范读取输入数据，以及将需要保存的模型写出到指定路径下，更加具体介绍请见文档：[提交训练作业](https://help.aliyun.com/document_detail/2261505.html)。\n",
    "\n",
    "修改后的训练作业拉起脚本如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "%%writefile ChatGLM2-6B/ptuning/train.sh\n",
    "\n",
    "PRE_SEQ_LEN=128\n",
    "LR=2e-2\n",
    "NUM_GPUS=`nvidia-smi --list-gpus | wc -l`\n",
    "\n",
    "torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \\\n",
    "    --do_train \\\n",
    "    --train_file /ml/input/data/train/train.json \\\n",
    "    --validation_file /ml/input/data/train/dev.json \\\n",
    "    --preprocessing_num_workers 10 \\\n",
    "    --prompt_column content \\\n",
    "    --response_column summary \\\n",
    "    --overwrite_cache \\\n",
    "    --model_name_or_path /ml/input/data/model \\\n",
    "    --output_dir /ml/output/model/ \\\n",
    "    --overwrite_output_dir \\\n",
    "    --max_source_length 64 \\\n",
    "    --max_target_length 128 \\\n",
    "    --per_device_train_batch_size 4 \\\n",
    "    --per_device_eval_batch_size 4 \\\n",
    "    --gradient_accumulation_steps 32 \\\n",
    "    --predict_with_generate \\\n",
    "    --num_train_epochs 10 \\\n",
    "    --save_strategy epoch \\\n",
    "    --learning_rate $LR \\\n",
    "    --pre_seq_len $PRE_SEQ_LEN \\\n",
    "    --quantization_bit 4\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里我们将使用PAI提供的PyTorch GPU训练镜像运行训练作业，需要安装部分第三方依赖包。用户可以通过提供`requirements.txt`的方式提供，相应的依赖会在训练作业执行前被安装到环境中\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "%%writefile ChatGLM2-6B/ptuning/requirements.txt\n",
    "# 模型需要的依赖\n",
    "transformers==4.30.2\n",
    "accelerate\n",
    "icetk\n",
    "cpm_kernels\n",
    "\n",
    "torch>=2.0,<2.1\n",
    "sentencepiece\n",
    "accelerate\n",
    "\n",
    "rouge_chinese\n",
    "nltk\n",
    "jieba\n",
    "datasets"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 提交训练作业\n",
    "\n",
    "我们将通过PAI Python SDK，将以上的训练作业提交到PAI执行。SDK在提交训练作业之后，会打印训练作业的链接，用户可以通过对应的链接查看作业的执行详情，输出日志。\n",
    "\n",
    "> Note：按当前示例教程使用的训练配置、数据集和机器规格，训练作业运行约10分钟左右。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "# 使用PAI提供的最新的PyTorch推理镜像\n",
    "image_uri = retrieve(\n",
    "    \"PyTorch\",\n",
    "    \"latest\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "\n",
    "\n",
    "est = Estimator(\n",
    "    command=\"bash train.sh\",  # 启动命令\n",
    "    source_dir=\"./ChatGLM2-6B/ptuning\",  # 训练代码目录.\n",
    "    image_uri=image_uri,  # 训练镜像\n",
    "    instance_type=\"ecs.gn6e-c12g1.3xlarge\",  # 使用的机器规格示例，V100(32G)\n",
    "    base_job_name=\"chatglm2_finetune_\",\n",
    ")\n",
    "\n",
    "\n",
    "# 提交训练作业\n",
    "est.fit(\n",
    "    inputs={\n",
    "        \"model\": model_uri,\n",
    "        \"train\": train_data_uri,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "默认`estimator.fit`会等待到作业执行完成。作业执行成功之后，用户可以通过`est.model_data()`获取输出模型在OSS上的路径地址。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(est.model_data())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用户可以通过`ossutil`或是SDK提供的便利方法将模型下载到本地:\n",
    "\n",
    "```python\n",
    "from pai.common.oss_util import download\n",
    "\n",
    "\n",
    "# 使用SDK的便利方法下载模型到本地.\n",
    "download(\n",
    "\toss_path=est.model_data(),\n",
    "\tlocal_path=\"./output_model\",\n",
    ")\n",
    "\n",
    "```"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 部署微调之后的模型\n",
    "\n",
    "微调训练之后获得的`checkpoints`，需要和原始的模型配合一起使用。我们需要通过以下代码获得对应的checkpoint路径.\n",
    "\n",
    "> 用户通过修改微调训练的代码，使用`Trainer.save_model()`显式的保存相应的checkpoints，则可以直接通过`estimator.model_data()`下获得相应的checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# 以上的训练作业超参设置中，我们设置`epochs=2`, checkpoints保存的策略是`每一个epochs保存`。\n",
    "# 默认最后一个checkpoint会被保存到`{output_dir}/checkpoint-2`路径下.\n",
    "# 通过以下路径，我们可以获得模型训练获得的最后一个checkpoint的OSS路径.\n",
    "\n",
    "checkpoint_uri = os.path.join(est.model_data(), \"checkpoint-10/\")\n",
    "print(checkpoint_uri)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将复用ChatGLM2部署的推理服务程序创建推理服务。与直接部署ChatGLM2的不同点在于我们还需要提供微调之后获得的checkpoints。\n",
    "\n",
    "通过`InferenceSpec.mount` API，我们可以将相应的OSS模型路径挂载到服务容器中，供推理服务程序使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pai.model import container_serving_spec, Model\n",
    "from pai.image import retrieve, ImageScope\n",
    "\n",
    "\n",
    "# InferenceSpec用于描述如何创建推理服务\n",
    "infer_spec = container_serving_spec(\n",
    "    image_uri=retrieve(  # 使用PAI提供的最新PyTorch的推理镜像\n",
    "        \"PyTorch\",\n",
    "        \"latest\",\n",
    "        accelerator_type=\"GPU\",\n",
    "        image_scope=ImageScope.INFERENCE,\n",
    "    ),\n",
    "    source_dir=\"./server_src\",  # 代码目录\n",
    "    command=\"python run.py\",  # 启动命令\n",
    ")\n",
    "\n",
    "\n",
    "# 将相应的checkpoints挂载到服务中，推理服务的程序通过检查目录(/ml/ptuning_checkpoints/)是否存在加载checkpoints\n",
    "infer_spec.mount(checkpoint_uri, \"/ml/ptuning_checkpoints\")\n",
    "print(infer_spec.to_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.utils import random_str\n",
    "\n",
    "m = Model(\n",
    "    model_data=model_uri,\n",
    "    inference_spec=infer_spec,\n",
    ")\n",
    "\n",
    "# 部署模型\n",
    "p = m.deploy(\n",
    "    service_name=\"chatglm_ft_{}\".format(random_str(6)),\n",
    "    instance_type=\"ecs.gn6i-c16g1.4xlarge\",  # 1 * T4\n",
    "    options={\n",
    "        # 配置EAS RPC框架的超时时间, 单位为毫秒\n",
    "        \"metadata.rpc.keepalive\": 20000,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "向推理服务发送请求，测试推理服务是否正常启动。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp = p.raw_predict(\n",
    "    {\n",
    "        \"prompt\": \"你好\",\n",
    "    },\n",
    ")\n",
    "print(resp.json())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基于以上微调后模型的推理服务，我们可以使用Gradio创建一个新的机器人。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import random\n",
    "import time\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot()\n",
    "    msg = gr.Textbox()\n",
    "    clear = gr.Button(\"Clear\")\n",
    "    submit = gr.Button(\"Submit\")\n",
    "\n",
    "    def respond(message, chat_history):\n",
    "\n",
    "        print(f\"Message: {message}\")\n",
    "        print(f\"ChatHistory: {chat_history}\")\n",
    "        resp = p.raw_predict(\n",
    "            {\n",
    "                \"prompt\": message,\n",
    "                \"history\": chat_history,\n",
    "            }\n",
    "        ).json()\n",
    "        print(f\"Response: {resp['response']}\")\n",
    "\n",
    "        chat_history.append((message, resp[\"response\"]))\n",
    "        return \"\", chat_history\n",
    "\n",
    "    submit.click(respond, [msg, chatbot], [msg, chatbot])\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在测试完成之后，可以通过`p.delete_service()`删除服务，释放资源。\n",
    "\n",
    "> 请注意，删除在线推理服务之后，对应的Gradio的应用将无法使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p.delete_service()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
